import os, json, shutil, random, multiprocessing, warnings

import pickle
import numpy as np
import matplotlib.pyplot as plt

from src.envs.line_environment import LineEnvironment
from src.envs.alanine_environment import AlanineDipeptideEnvironment
from src.envs.pendulum_environment import PendulumEnvironment
from src.envs.grid_environment import GridEnvironment
from src.envs.hypergrid_environment import HypergridEnvironment

from src.utils.plotting import Plotter

from src.metad.metadynamics_sampler import MetadynamicsSampler

from src.replay_buffer import ReplayBuffer

from src.utils.networks import NeuralNet, MultiHeadedMLP

from src.gfn.tb_gfn import TBGFlowNet
from src.gfn.stb_gfn import STBGFlowNet
from src.gfn.db_gfn import DBGFlowNet

from tqdm import tqdm
import torch

class Experiment:
    """
    Class for running experiments.

    Constructing an ``Experiment`` (re)creates the directory
    ``config["master_dir"]/config["exp_name"]`` -- deleting any previous
    contents -- and writes ``config.json`` into it. ``train`` then runs
    ``config["repeats"]`` independent training repeats. Each repeat saves its
    models, pickled sampler/buffer, data files and figures to a ``repeat_<i>``
    subdirectory. The tracked data series are additionally appended, one line
    per repeat, to files in the experiment directory itself.
    """

    def __init__(self, config, out=None):
        """
        Args:
            config: experiment configuration dictionary.
            out: optional output target forwarded to ``Plotter.plot_sampler``
                for live visualisation. When set (and the device is not MPS),
                only repeat 0 is trained, in the current process.
        """
        self.config = config
        self.out = out
        self._init_threads()
        self._init_exp_dir()

    def train(self):
        """
        Trains the GFN for the specified number of iterations and repeats.

        - On an MPS device, all repeats run sequentially in this process.
        - When ``self.out`` is None, the repeats are distributed over a
          multiprocessing pool of ``config["n_processes"]`` workers
          (``-1`` means one worker per repeat).
        - Otherwise (live visualisation), only repeat 0 is trained, in this
          process.
        """
        if self.config["device"] == "mps":
            warnings.warn("MPS device detected. Multiprocessing is not supported on MPS. Running in single process mode.")
            for repeat in range(self.config["repeats"]):
                self._train_single(repeat)
        elif self.out is None:
            n_processes = self.config["n_processes"]
            num_processes = self.config["repeats"] if n_processes == -1 else n_processes
            with multiprocessing.Pool(num_processes) as pool:
                pool.starmap(self._train_single, [(repeat,) for repeat in range(self.config["repeats"])])
        else:
            if self.config["repeats"] > 1:
                warnings.warn("Cannot multiprocess when visualizing progress. Running in single process mode.")
            self._train_single(0)

    def _train_single(self, repeat):
        """
        Runs one full training repeat, seeded with ``repeat``.

        Builds fresh environment/GFN/sampler/buffer/plotter objects and trains
        for ``config["n_iterations"]`` iterations. It records the tracked
        quantities every ``freq_data_save`` iterations, then writes the data
        files, models and figures.
        """
        self._seed_all(repeat)
        self._setup_objs()
        save_iterations, losses, logZs, L1_policy_error = [], [], [], []
        self._init_data_files(repeat)

        # Number of upcoming iterations for which plotting is skipped, as
        # requested by the plotter's previous return value.
        skip_value = 0
        for iteration in tqdm(range(self.config["n_iterations"])):

            # Running trajectories
            if (iteration + 1) % self.config["freq_md"] == 0 and self.config["metad"]["active"]:
                trajs, loss = self.gfn.run_metadynamics_batch(self.mds)
                if self.config["replay_buffer"]["active"]:
                    self.rb.batch_push(trajs)
                if not self.config["metad"]["train"]:
                    # Metadynamics batch only feeds the replay buffer: skip the
                    # model update, data saving and plotting for this iteration.
                    continue
            elif (iteration + 1) % self.config["freq_ns"] == 0 and self.config["gfn"]["nested_sampling"]:
                trajs, loss = self.gfn.run_nested_sampling_batch()
            elif (iteration + 1) % self.config["freq_rb"] == 0 and len(self.rb) > self.config["batch_size"] and self.config["replay_buffer"]["active"]:
                trajs, loss = self.gfn.run_replay_batch(self.rb)
            else:
                trajs, loss = self.gfn.run_batch()
                # With metadynamics active, on-policy trajectories are only
                # pushed on the iteration right before a metadynamics batch.
                if not self.config["metad"]["active"] or (iteration + 2) % self.config["freq_md"] == 0:
                    if self.config["replay_buffer"]["active"]:
                        self.rb.batch_push(trajs)

            # Updating the model
            self.gfn.step(loss)

            # Updating data
            if iteration % self.config["data_saving"]["freq_data_save"] == 0:
                self._update_data_arrays(losses, logZs, L1_policy_error, save_iterations, iteration, loss)

            # Graphics and progress bar
            if self.out is not None:
                if skip_value > 0:
                    skip_value -= 1
                    continue
                if (iteration + 1) % self.config["freq_plot"] == 0 and self.config["plot"]:
                    result = self.plotter.plot_sampler(save_iterations, self.mds, self.rb, self.out, iteration, losses, logZs, L1_policy_error)
                    # The plotter returns -1 to stop training, otherwise the
                    # number of iterations to skip plotting for.
                    if result == -1:
                        break
                    skip_value = result

        # Save data to a new line of the appropriate files
        self._write_data_files(losses, logZs, L1_policy_error, save_iterations, repeat=repeat)
        self._save_models_and_figures(save_iterations, losses, logZs, L1_policy_error, repeat)

    def _init_threads(self):
        """Limits torch's intra-op thread count unless ``n_threads`` is -1."""
        if self.config["n_threads"] != -1:
            torch.set_num_threads(self.config["n_threads"])

    def _experiment_dir(self, repeat=None):
        """Returns the experiment directory, or its ``repeat_<repeat>`` subdirectory."""
        save_dir = os.path.join(self.config["master_dir"], self.config["exp_name"])
        if repeat is None:
            return save_dir
        return os.path.join(save_dir, "repeat_" + str(repeat))

    def _init_exp_dir(self):
        """
        Creates a fresh experiment directory and saves the config to it as JSON.

        Any existing directory at ``master_dir/exp_name`` is deleted first.
        Missing parent directories (including ``master_dir``) are created.
        """
        save_dir = self._experiment_dir()
        if os.path.exists(save_dir):
            # delete the experiment if it already exists
            shutil.rmtree(save_dir)
            print("Experiment already exists, deleting it")
        os.makedirs(save_dir)

        # Save config using json
        with open(os.path.join(save_dir, "config.json"), "w") as f:
            f.write(json.dumps(self.config, indent=4))

    def _init_env(self):
        """
        Instantiates ``self.env`` from ``config["env"]["env_name"]``.

        Raises:
            ValueError: if the environment name is unknown.
        """
        env_classes = {
            "line": LineEnvironment,
            "alanine": AlanineDipeptideEnvironment,
            "pendulum": PendulumEnvironment,
            "grid": GridEnvironment,
            "hypergrid": HypergridEnvironment,
        }
        env_name = self.config["env"]["env_name"]
        if env_name not in env_classes:
            raise ValueError(f"Unknown environment: {env_name}")
        self.env = env_classes[env_name](self.config)

    def _create_networks(self):
        """
        Builds the forward, backward and (for non-TB losses) flow networks.

        Returns:
            ``(forward_net, backward_net, flow_net)``. ``flow_net`` is None for
            the TB loss. With ``tie_weights`` the backward and flow networks
            share the forward network's torso.
        """
        input_dim = self.env.feature_dim + 1
        hidden_dim = self.config["gfn"]["hidden_dim"]
        n_hidden_layers = self.config["gfn"]["n_hidden_layers"]
        if self.config["gfn"]["thompson_sampling"]:
            forward_net = MultiHeadedMLP(input_dim=input_dim, output_dim=self.env.output_dim, n_heads=self.config["gfn"]["thompson_sampling_num_heads"], hidden_dim=hidden_dim, n_hidden_layers=n_hidden_layers)
        else:
            forward_net = NeuralNet(input_dim=input_dim, output_dim=self.env.output_dim, hidden_dim=hidden_dim, n_hidden_layers=n_hidden_layers)
        if self.config["gfn"]["tie_weights"]:
            backward_net = NeuralNet(input_dim=input_dim, output_dim=self.env.output_dim, torso=forward_net.torso)
            flow_net = NeuralNet(input_dim=input_dim, output_dim=1, torso=forward_net.torso) if self.config["gfn"]["loss"] != "TB" else None
        else:
            backward_net = NeuralNet(input_dim=input_dim, output_dim=self.env.output_dim, hidden_dim=hidden_dim, n_hidden_layers=n_hidden_layers)
            flow_net = NeuralNet(input_dim=input_dim, output_dim=1, hidden_dim=hidden_dim, n_hidden_layers=n_hidden_layers) if self.config["gfn"]["loss"] != "TB" else None

        return forward_net, backward_net, flow_net

    def _create_gfn(self, forward_net, backward_net, flow_net=None):
        """
        Builds the GFlowNet matching ``config["gfn"]["loss"]`` (TB, STB or DB).

        Raises:
            ValueError: if the loss type is unknown.
        """
        loss_type = self.config["gfn"]["loss"]
        if loss_type == "TB":
            return TBGFlowNet(self.env, self.config, forward_model=forward_net, backward_model=backward_net, tied=self.config["gfn"]["tie_weights"])
        elif loss_type == "STB":
            return STBGFlowNet(self.env, self.config, forward_model=forward_net, backward_model=backward_net, logF_model=flow_net, tied=self.config["gfn"]["tie_weights"], lamda=self.config["gfn"]["lambda"])
        elif loss_type == "DB":
            return DBGFlowNet(self.env, self.config, forward_model=forward_net, backward_model=backward_net, logF_model=flow_net, tied=self.config["gfn"]["tie_weights"])
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")

    def _init_gfn(self):
        """Creates the networks and wraps them in ``self.gfn``."""
        forward_net, backward_net, flow_net = self._create_networks()
        self.gfn = self._create_gfn(forward_net, backward_net, flow_net)

    def _setup_objs(self):
        """ Sets up the environment, gfn, metadynamics sampler and replay buffer. """
        self._init_env()
        self._init_gfn()
        self.mds = MetadynamicsSampler(self.config, self.env)
        self.rb = ReplayBuffer(self.config, self.env, self.gfn)
        self.plotter = Plotter(self.config, self.env, self.gfn)

    def _saves_iterations(self):
        """Returns True if any tracked quantity is enabled, i.e. iterations must be recorded."""
        flags = self.config["data_saving"]
        return any(flags.get(key, False) for key in ("loss", "logZ", "L1_policy_error", "L1_potential_kde_error"))

    def _enabled_data_files(self):
        """Returns the stems of the enabled data files (``<stem>.txt``), in write order."""
        flags = self.config["data_saving"]
        stems = ["iterations"] if self._saves_iterations() else []
        stems.extend(key for key in ("loss", "logZ", "L1_policy_error") if flags[key])
        return stems

    def _init_data_files(self, repeat):
        """Creates the repeat directory and an empty data file for each enabled series."""
        save_dir = self._experiment_dir(repeat)
        os.mkdir(save_dir)
        for stem in self._enabled_data_files():
            # "a+" creates the file if missing without truncating it.
            with open(os.path.join(save_dir, stem + ".txt"), "a+"):
                pass

    def _update_data_arrays(self, losses, logZs, L1_policy_error, save_iterations, iteration, loss):
        """Appends the current iteration's loss, logZ and L1 policy error to the data arrays."""
        if self._saves_iterations():
            save_iterations.append(iteration)

        if self.config["data_saving"]["loss"]:
            losses.append(loss.item())

        if self.config["data_saving"]["logZ"]:
            loss_type = self.config["gfn"]["loss"]
            if loss_type == "TB":
                logZs.append(self.gfn.logZ.item())
            elif loss_type in ("STB", "DB"):
                # logZ is read off the flow network at an all-zeros input.
                # NOTE(review): _create_networks sizes the flow network's input as
                # env.feature_dim + 1 while this probe has env.dim + 1 entries;
                # confirm the two agree for every environment.
                zeros_tensor = torch.zeros(self.env.dim + 1, device=self.config["device"])
                # Pure logging read-out: no autograd graph is needed.
                with torch.no_grad():
                    logZs.append(self.gfn.logF_model(zeros_tensor).item())

        if self.config["data_saving"]["L1_policy_error"]:
            L1_policy_error.append(self.gfn.L1_error(samples=self.config["data_saving"]["on_policy_samples"]))

    def _write_data_files(self, losses, logZs, L1_policy_error, save_iterations, repeat=None):
        """
        Appends each enabled data series as one comma-separated line.

        The line always goes to the experiment-level file (one line per
        repeat). When ``repeat`` is given, it also goes to that repeat's own
        file, so the data stays attributable to its repeat even though pool
        workers may finish in any order.
        """
        series = {
            "iterations": save_iterations,
            "loss": losses,
            "logZ": logZs,
            "L1_policy_error": L1_policy_error,
        }
        target_dirs = [self._experiment_dir()]
        if repeat is not None:
            target_dirs.append(self._experiment_dir(repeat))

        for save_dir in target_dirs:
            for stem in self._enabled_data_files():
                with open(os.path.join(save_dir, stem + ".txt"), "a") as f:
                    f.write(",".join(map(str, series[stem])) + "\n")

    def _save_models_and_figures(self, iterations, losses, logZs, L1_policy_error, repeat):
        """Saves the trained models, pickled sampler/buffer and evolution plots of one repeat."""
        save_dir = self._experiment_dir(repeat)
        torch.save(self.gfn.forward_model, os.path.join(save_dir, "forward_model.pt"))
        torch.save(self.gfn.backward_model, os.path.join(save_dir, "backward_model.pt"))

        # save the replay buffer to file by pickling it
        with open(os.path.join(save_dir, "replay_buffer.pkl"), "wb") as f:
            pickle.dump(self.rb, f)

        # save the metadynamics sampler
        with open(os.path.join(save_dir, "metadynamics_sampler.pkl"), "wb") as f:
            pickle.dump(self.mds, f)

        loss_type = self.config["gfn"]["loss"]
        if loss_type == "TB":
            torch.save(self.gfn.logZ, os.path.join(save_dir, "logZ.pt"))
        elif loss_type in ("STB", "DB"):
            torch.save(self.gfn.logF_model, os.path.join(save_dir, "logF_model.pt"))

        # (series name, values, extra plot_evolution kwargs). The L1 plot passes
        # no ``log`` argument, i.e. it uses plot_evolution's default.
        evolutions = (
            ("loss", losses, {"log": True}),
            ("logZ", logZs, {"log": False}),
            ("L1_policy_error", L1_policy_error, {}),
        )
        for name, values, plot_kwargs in evolutions:
            if not self.config["data_saving"][name]:
                continue
            fig, ax = plt.subplots()
            try:
                self.plotter.plot_evolution(ax, iterations, values, name, save_dir, **plot_kwargs)
            finally:
                # pyplot keeps every open figure alive; close it so figures do
                # not accumulate across sequential repeats.
                plt.close(fig)

    def _seed_all(self, seed: int):
        """Resets the random seed of the Python, NumPy and torch generators."""
        random.seed(seed)
        np.random.seed(seed)
        # torch.manual_seed seeds the generators of all devices (CPU and CUDA).
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False